package com.yeziji.config.aspect;

import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.digest.MD5;
import com.alibaba.fastjson2.JSONObject;
import com.yeziji.annotation.ApiLock;
import com.yeziji.common.CommonResult;
import com.yeziji.common.base.ApiOverclockDegradationHandlerBase;
import com.yeziji.common.context.OnlineContext;
import com.yeziji.constant.ApiLockMonitorMode;
import com.yeziji.service.cache.RedisService;
import com.yeziji.utils.DateUtils;
import com.yeziji.utils.ServletUtils;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.CodeSignature;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;

/**
 * API lock aspect.
 * <p>
 * Intercepts types/methods annotated with {@link ApiLock}. For each call it builds a monitor key from the
 * request URI plus the components selected by {@link ApiLock#monitors()}. It then:
 * <ol>
 *     <li>rejects the call while a ban key for that monitor key is still alive;</li>
 *     <li>rejects the call if the same monitor key was seen within {@link ApiLock#acceptInterval()} ms,
 *     optionally banning the key for {@link ApiLock#banTime()} ms;</li>
 *     <li>otherwise records the key for the interval and proceeds to the target method.</li>
 * </ol>
 * A rejected call is answered by the configured {@link ApiLock#downgrade()} handler. When that handler
 * returns {@code null}, the call is answered by a failed {@link CommonResult}.
 *
 * @author hwy
 * @since 2024/02/02 21:01
 **/
@Slf4j
@Aspect
@Component
public class ApiLockAspect {
    /** Redis key prefix format; the request URI is substituted into it. */
    private static final String KEY_PREFIX_FORMAT = "api_lock_monitor:%s:";
    /** Suffix appended to a monitor key to form its ban key. */
    private static final String BAN_KEY_SUFFIX = ":ban";

    @Resource
    private RedisService redisService;
    // NOTE(review): these are singleton fields, so they are assumed to be Spring's per-request proxies — confirm
    @Resource
    private HttpServletRequest request;
    @Resource
    private HttpServletResponse response;

    /**
     * Around advice enforcing the {@link ApiLock} interval/ban rules.
     *
     * @param pjp     the intercepted join point
     * @param apiLock the bound {@link ApiLock} annotation (from the method or the declaring type)
     * @return the target's result, or the downgrade/failure result when the call is rejected
     * @throws Throwable anything thrown by the target method or by the downgrade handler invocation
     */
    @Around("@within(apiLock) || @annotation(apiLock)")
    public Object apiLockHandler(ProceedingJoinPoint pjp, ApiLock apiLock) throws Throwable {
        String monitorKey = this.getMonitorKey(pjp, apiLock.monitors());
        // A blank key means monitoring is disabled (ApiLockMonitorMode.NONE): run the target directly
        if (StrUtil.isBlank(monitorKey)) {
            return pjp.proceed();
        }

        // 1. Reject callers whose ban key is still alive
        String banKey = monitorKey + BAN_KEY_SUFFIX;
        long banTime = redisService.getExpireTime(banKey);
        if (banTime > 0L) {
            // NOTE(review): the TTL is formatted as seconds, so getExpireTime is assumed to return seconds — confirm
            log.warn("{} 封禁剩余时间: {}", monitorKey, DateUtils.formatSecondTime(banTime));
            return this.rejectWith("ban", apiLock, "system-api-lock-error.access-denied");
        }

        // 2. Reject calls repeated within the accept interval, optionally banning the key
        // NOTE(review): get-then-set is not atomic, so concurrent first calls can all pass the check.
        // TODO switch to an atomic "set if absent" if RedisService offers one.
        if (redisService.get(monitorKey) != null) {
            log.warn("{} 间隔内超频访问, time: {}ms", monitorKey, apiLock.acceptInterval());
            long ban = apiLock.banTime();
            if (ban > 0L) {
                log.warn("ban {}, time: {}ms", monitorKey, ban);
                redisService.set(banKey, ban, ban, TimeUnit.MILLISECONDS);
            }
            return this.rejectWith("interval", apiLock, "system-api-lock-error.dont-visit-frequently");
        }

        // 3. First call in this window: open the interval, then run the target
        redisService.set(monitorKey, true, apiLock.acceptInterval(), TimeUnit.MILLISECONDS);
        return pjp.proceed();
    }

    /**
     * Builds the response for a rejected call.
     *
     * @param handlerMethod    name of the downgrade handler method to invoke ({@code "ban"} or {@code "interval"})
     * @param apiLock          the annotation carrying the downgrade handler class
     * @param failedMessageKey message key used for {@link CommonResult#failed} when the handler returns {@code null}
     * @return the downgrade handler's result, or a failed {@link CommonResult}
     */
    private Object rejectWith(String handlerMethod, ApiLock apiLock, String failedMessageKey)
            throws NoSuchMethodException, InvocationTargetException, IllegalAccessException, InstantiationException {
        Object degraded = this.downgradeHandler(handlerMethod, apiLock);
        if (degraded != null) {
            return degraded;
        }
        // Built only when actually needed (the previous Optional#orElse built it eagerly)
        return CommonResult.failed(failedMessageKey);
    }

    /**
     * Builds the monitor key.
     * <p>
     * The key is {@code api_lock_monitor:<uri>:} followed by the colon-joined components of each selected mode:
     * <ul>
     *     <li>{@code IP}: the caller's IP, when non-blank;</li>
     *     <li>{@code USER_ID}: the online user's id, when a session exists;</li>
     *     <li>{@code PARAMS}: the MD5 of the method arguments as JSON.</li>
     * </ul>
     * A component that is unavailable is simply omitted.
     *
     * @param pjp      the intercepted join point (used for {@code PARAMS})
     * @param monitors the monitor modes to combine
     * @return {@link String} the monitor key, or {@code ""} if any mode is {@code NONE} (monitoring disabled)
     */
    public String getMonitorKey(ProceedingJoinPoint pjp, ApiLockMonitorMode[] monitors) {
        String uri = request.getRequestURI();
        StringJoiner keyJoiner = new StringJoiner(":", String.format(KEY_PREFIX_FORMAT, uri), "");
        for (ApiLockMonitorMode monitor : monitors) {
            switch (monitor) {
                case NONE: {
                    // NONE anywhere in the list disables monitoring entirely
                    return "";
                }
                case IP: {
                    String ipAddress = ServletUtils.getIpAddress(request);
                    if (StrUtil.isNotBlank(ipAddress)) {
                        keyJoiner.add(ipAddress);
                    }
                    break;
                }
                case USER_ID: {
                    if (OnlineContext.isHasSession()) {
                        Optional.ofNullable(OnlineContext.getUserId()).map(String::valueOf).ifPresent(keyJoiner::add);
                    }
                    break;
                }
                case PARAMS: {
                    String param = this.getParam(pjp);
                    if (StrUtil.isNotBlank(param)) {
                        // Hash the arguments so the key length stays bounded
                        keyJoiner.add(MD5.create().digestHex(param));
                    }
                    break;
                }
                default:
                    // Other modes contribute nothing to the key
                    break;
            }
        }
        return keyJoiner.toString();
    }

    /**
     * Serializes the join point's arguments to JSON as a {@code parameterName -> value} map.
     * <p>
     * {@link HttpServletRequest}/{@link HttpServletResponse} arguments are skipped. They are not request
     * parameters, and their serialized form is presumably not stable between calls, which would defeat the
     * {@code PARAMS} key. When parameter names cannot be discovered ({@code getParameterNames()} returns
     * {@code null}), positional names {@code arg0, arg1, ...} are used instead of failing.
     *
     * @param proceedingJoinPoint the intercepted join point
     * @return {@link String} the JSON string of the argument map ({@code "{}"} when there are no usable args)
     */
    public String getParam(ProceedingJoinPoint proceedingJoinPoint) {
        Map<String, Object> map = new HashMap<>();
        Object[] values = proceedingJoinPoint.getArgs();
        if (values == null) {
            return JSONObject.toJSONString(map);
        }
        String[] names = ((CodeSignature) proceedingJoinPoint.getSignature()).getParameterNames();
        for (int i = 0; i < values.length; i++) {
            Object value = values[i];
            if (value instanceof HttpServletRequest || value instanceof HttpServletResponse) {
                continue;
            }
            String name = (names != null && i < names.length) ? names[i] : "arg" + i;
            map.put(name, value);
        }
        return JSONObject.toJSONString(map);
    }

    /**
     * Invokes the downgrade handler for a rejected call.
     * <p>
     * A fresh instance of {@link ApiLock#downgrade()} is created through its no-arg constructor. Its public
     * {@code methodName(HttpServletRequest, HttpServletResponse)} method is then invoked with the current
     * request/response.
     *
     * @param methodName name of the handler method to invoke
     * @param apiLock    the annotation carrying the handler class
     * @return {@link Object} the handler method's return value ({@code null} if it returns {@code null})
     *
     * @throws NoSuchMethodException     the handler has no matching public method, or no no-arg constructor
     * @throws InvocationTargetException the constructor or the handler method itself threw (the original
     *                                   exception is available via {@link InvocationTargetException#getCause()})
     * @throws IllegalAccessException    the constructor or method is not accessible
     * @throws InstantiationException    the handler class cannot be instantiated (e.g. it is abstract)
     */
    public Object downgradeHandler(String methodName, ApiLock apiLock) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException, InstantiationException {
        Class<? extends ApiOverclockDegradationHandlerBase> downgrade = apiLock.downgrade();
        // Defensive only: annotation members are never null, so this branch is effectively always taken
        if (downgrade != null) {
            Method method = downgrade.getMethod(methodName, HttpServletRequest.class, HttpServletResponse.class);
            ApiOverclockDegradationHandlerBase handler = downgrade.getDeclaredConstructor().newInstance();
            return method.invoke(handler, request, response);
        }
        return null;
    }
}
